enable StaticCache for assisted generation #34797
base: main
Conversation
@gante, could you please take a look? Thanks.
@yao-matrix hey, gante is currently on a long vacation, so I reviewed the PR for him. Thanks for adding support for this, super cool work! I left a few comments, and we'll also need tests in the tests/generation/test_utils.py file. I guess static cache now works with all types of candidate generators, right?
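For context, here is a hedged usage sketch of the feature under review: assisted generation with cache_implementation="static". The checkpoints, prompt, and token counts are illustrative assumptions, not values from the PR.

```python
# Hedged usage sketch of assisted (speculative) decoding with a static cache.
# Checkpoints are illustrative; any main/assistant pair that shares a tokenizer
# and supports StaticCache should exercise the same code path.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-360M")
assistant = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")

inputs = tok("Speculative decoding works by", return_tensors="pt")
out = model.generate(
    **inputs,
    assistant_model=assistant,       # turns on assisted generation
    cache_implementation="static",   # the combination this PR enables
    max_new_tokens=32,
)
print(tok.decode(out[0], skip_special_tokens=True))
```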
src/transformers/generation/utils.py
# Pre-build the assistant model's cache with the same max_cache_len as the main model,
# so the assistant's static cache is large enough for the whole assisted-decoding loop.
if assistant_model is not None:
    assistant_model._get_cache(
        cache_implementation=generation_config.cache_implementation,
        batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
        max_cache_len=max_cache_length,
        device=device,
        model_kwargs=model_kwargs,
    )
Hmm, I think it will be called on the assistant model when we call assistant.generate(), so there is no need for this. We only need to remove self.generation_config.cache_implementation = None in the candidate generator.
The thing is: if we leave it to assistant_model.generate (called from get_candidates) to initialize the cache, then on the first call max_new_tokens is set to max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1), so the cache length ends up as int(self.num_assistant_tokens) + prompt_len, which is smaller than the actually needed cache length of max_token_length + prompt_length, and this leads to an assertion failure during generation. So the key point is that the assistant model's cache length should be the same as the main model's here. I also noticed that this function takes assistant_model as an argument but does not use it; I think it may be there for cases like this. That's the rationale behind the change.
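To make the size mismatch concrete, here is a small illustrative calculation; the numbers are made up, and the formulas just mirror the ones quoted above.

```python
# Illustrative numbers only; the formulas mirror the discussion above.
prompt_len = 32
max_new_tokens_main = 256      # the main model's generation budget
num_assistant_tokens = 5       # candidate tokens proposed per assistant call

# If the assistant's static cache were sized on its first generate() call, where
# max_new_tokens = min(int(num_assistant_tokens), max_length - new_cur_len - 1),
# the cache would only cover:
first_call_cache_len = int(num_assistant_tokens) + prompt_len    # 37

# But over the whole assisted-decoding loop the assistant has to hold keys/values
# for roughly the full main-model sequence:
required_cache_len = prompt_len + max_new_tokens_main            # 288

# 37 < 288: later cache writes would overflow the undersized static cache, which is
# why the assistant's cache is sized from the main model's max cache length instead.
assert first_call_cache_len < required_cache_len
```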
Oh, I see, that makes sense. Then we can leave the cache init here.
LGTM! We need some tests, then I'll request a review from the core maintainer; after that we can merge.
@zucchini-nlp, the test_utils CI pass rate is the same before and after this PR (results below), so no regressions are introduced.
Thanks for reviewing.
@yao-matrix no worries if some tests are failing and they are not related to the PR changes. They might just be flaky, or will be fixed on …
@zucchini-nlp, any more comments for me to iterate on? Thanks.
@yao-matrix no, the only thing is the CI, which is failing now. I showed the relevant test in a prev comment, and if you can, add one more test in … At the end you need to run …
@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
    if cache_implementation == "static":
        self.skipTest("Gemma2 has HybridCache which is not compatible with assisted decoding StaticCache")
    pass
Let's not skip entirely, but only the static-cache test, as we still need to check if assisted generation works in Gemma2 :) Maybe it will be skipped by the model's _supports_static_cache as I've commented above, but if not, we can skip only test_assisted_decoding_with_num_logits_to_keep_1_static (maybe it's called a bit differently).
I switched to _supports_static_cache to skip the case. Gemma2 is a bit different: it uses HybridCache but still claims _supports_static_cache = True, so I keep skipping it in the model test file. I will remove this skip after enabling HybridCache for assisted decoding, which I plan to do after this PR (pure StaticCache) is merged. Thanks.
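A hedged, self-contained sketch of the skip pattern being discussed; the real tests live in the transformers suite (GenerationTesterMixin and the Gemma2 test file) and are worded differently.

```python
# Hedged sketch: skip only the static-cache variant, so the dynamic-cache variant
# still exercises assisted generation. ToyModel stands in for a transformers model
# class; Gemma2 is special because it reports _supports_static_cache = True while
# actually using HybridCache, hence the extra model-level skip mentioned above.
import unittest
from parameterized import parameterized


class ToyModel:
    _supports_static_cache = False


class AssistedDecodingTests(unittest.TestCase):
    @parameterized.expand([(None,), ("static",)])
    def test_assisted_decoding(self, cache_implementation):
        if cache_implementation == "static" and not ToyModel._supports_static_cache:
            self.skipTest("ToyModel does not support StaticCache")
        # ... run assisted generation with the chosen cache implementation ...


if __name__ == "__main__":
    unittest.main()
```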
Looks very nice, but we need to add a compile test to make sure this is compile-compatible! The whole point of static cache is -> compile! 🤗
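A hedged sketch of what such a compile-compatibility check could look like: the same kind of setup as the earlier usage sketch, with the main model's forward compiled before generating. The checkpoints and compile mode are illustrative assumptions, not the PR's actual test.

```python
# Hedged sketch: assisted generation with a static cache plus torch.compile.
# Checkpoints and settings are illustrative; the PR's real compile test may differ.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-360M")
assistant = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")

# Static caches give fixed tensor shapes, which is what makes compiling the
# forward pass worthwhile (fewer recompilations, CUDA-graph friendly).
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tok("Static caches keep shapes fixed, so", return_tensors="pt")
out = model.generate(
    **inputs,
    assistant_model=assistant,
    cache_implementation="static",
    max_new_tokens=20,
)
print(tok.decode(out[0], skip_special_tokens=True))
```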
@ArthurZucker I added a compile test.
@ArthurZucker @zucchini-nlp, please let me know if you have any further comments, thanks. BTW, I checked the failed CI case; it is not relevant to my changes.
Thanks, I re-triggered the tests; let's wait for the core maintainer.
@ArthurZucker, @zucchini-nlp, I am wondering: is it possible we can get this PR merged in 2024? :)
@zucchini-nlp @ArthurZucker, any further comments on this?
@gante, I implemented a version for this issue: #32946. Please help comment, and I can iterate. Thanks.